#
#  Copyright 2021 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import base64
import datetime
import io
import json
import os
import pickle
import socket
import time
import typing
import uuid
from enum import Enum, IntEnum
from functools import wraps

from pyoi.types import BaseType
from pyoi.util.log_utils import getLogger

# Module-level logger from pyoi.util.log_utils; ``trys`` reports retry
# warnings and final failures through it.
LOGGER = getLogger()


class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder for dates, enums, sets, classes and pyoi ``BaseType`` objects.

    Accepts one extra keyword argument, ``with_type`` (default False). When it
    is True, ``BaseType`` instances are encoded with ``to_dict_with_type()``
    instead of ``to_dict()``.
    """

    def __init__(self, **kwargs):
        # Remove our private option before forwarding the rest to the base
        # class, whose __init__ does not accept it.
        self._with_type = kwargs.pop("with_type", False)
        super().__init__(**kwargs)

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Unsupported objects are delegated to the base implementation, which
        raises ``TypeError``.
        """
        # datetime is a subclass of date, so it must be checked first.
        if isinstance(obj, datetime.datetime):
            return obj.strftime('%Y-%m-%d %H:%M:%S')
        if isinstance(obj, datetime.date):
            return obj.strftime('%Y-%m-%d')
        if isinstance(obj, datetime.timedelta):
            return str(obj)
        # IntEnum is a subclass of Enum, so it must be checked before Enum;
        # with the reverse order this branch could never be reached. (The
        # stdlib encoder usually emits int subclasses itself, so this mostly
        # matters when default() is called directly.)
        if isinstance(obj, IntEnum):
            return obj.value
        if isinstance(obj, Enum):
            return obj.name
        if isinstance(obj, set):
            return list(obj)
        if isinstance(obj, BaseType):
            return obj.to_dict_with_type() if self._with_type else obj.to_dict()
        if isinstance(obj, type):
            return obj.__name__
        return super().default(obj)


def oi_uuid():
    """Return a new version-1 (time-based) UUID as a 32-char lowercase hex string."""
    generated = uuid.uuid1()
    return generated.hex


def get_object_classpath(obj):
    """Return ``"<module>.<qualname>"`` for the class of ``obj``.

    The module prefix is omitted when the class reports no module or the
    module is named ``"__builtin__"``.
    """
    klass = type(obj)
    qualified = klass.__qualname__
    mod_name = klass.__module__
    # NOTE(review): "__builtin__" is the Python 2 name of the builtins module;
    # on Python 3 builtins report "builtins", so they still get a prefix
    # (e.g. "builtins.int"). Kept as-is in case callers depend on that.
    if mod_name is None or mod_name == "__builtin__":
        return qualified
    return ".".join((mod_name, qualified))


def string_to_bytes(string) -> bytes:
    """Encode ``string`` as UTF-8; ``bytes`` input is returned unchanged."""
    if isinstance(string, bytes):
        return string
    return string.encode(encoding="utf-8")


def bytes_to_string(src: bytes):
    """Decode UTF-8 ``src`` to ``str`` (raises UnicodeDecodeError on invalid data)."""
    text = src.decode("utf-8")
    return text


def json_dumps(src, byte=False, indent=None, with_type=False) -> typing.Union[str, bytes]:
    """Serialize ``src`` to JSON with :class:`CustomJSONEncoder`.

    :param src: object to serialize.
    :param byte: when True, return UTF-8 ``bytes`` instead of ``str``.
    :param indent: forwarded to :func:`json.dumps`.
    :param with_type: forwarded to the encoder; ``BaseType`` objects are then
        encoded with ``to_dict_with_type()``.
    """
    text = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type)
    return string_to_bytes(text) if byte else text


def json_loads(src: typing.Union[str, bytes], object_hook=None, object_pairs_hook=None):
    """Parse JSON from ``str`` or UTF-8 ``bytes``.

    ``object_hook`` and ``object_pairs_hook`` are forwarded to :func:`json.loads`.
    """
    text = bytes_to_string(src) if isinstance(src, bytes) else src
    return json.loads(text, object_hook=object_hook, object_pairs_hook=object_pairs_hook)


def current_timestamp():
    """Return the current Unix time in milliseconds, truncated to an int."""
    millis = time.time() * 1000
    return int(millis)


def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
    """Format a millisecond Unix timestamp as a local-time string.

    :param timestamp: milliseconds since the epoch; anything ``int()`` accepts.
    :param format_string: :func:`time.strftime` format.
    """
    seconds = int(timestamp) / 1000
    return time.strftime(format_string, time.localtime(seconds))


def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
    """Parse a local-time string and return it as Unix time in milliseconds."""
    parsed = time.strptime(time_str, format_string)
    return int(time.mktime(parsed) * 1000)


def serialize_b64(src, to_str=False):
    """Pickle ``src`` and Base64-encode the pickle.

    Returns ``bytes``, or a ``str`` when ``to_str`` is True.
    """
    encoded = base64.b64encode(pickle.dumps(src))
    return bytes_to_string(encoded) if to_str else encoded


def deserialize_b64(src, using_safe_module=False):
    """Base64-decode ``src`` (``str`` or ``bytes``) and unpickle the result.

    :param using_safe_module: when True, load through ``restricted_loads``
        instead of :func:`pickle.loads`.
    NOTE(review): the default path uses ``pickle.loads``, which can execute
    arbitrary code; only call it on trusted data.
    """
    raw = src
    if isinstance(raw, str):
        raw = string_to_bytes(raw)
    payload = base64.b64decode(raw)
    if not using_safe_module:
        return pickle.loads(payload)
    return restricted_loads(payload)


def serialize_hex(src):
    """Pickle ``src`` and return the pickle as a lowercase hex string."""
    payload = pickle.dumps(src)
    return payload.hex()


def deserialize_hex(src, using_safe_module=False):
    """Decode a hex string (as produced by ``serialize_hex``) and unpickle it.

    :param src: hex-encoded pickle payload.
    :param using_safe_module: when True, load through ``restricted_loads``
        instead of :func:`pickle.loads`, mirroring ``deserialize_b64``.
        Defaults to False, so existing callers keep the old behavior.
    NOTE(review): the default path uses ``pickle.loads``, which can execute
    arbitrary code; never pass untrusted input without ``using_safe_module``.
    """
    payload = bytes.fromhex(src)
    if using_safe_module:
        return restricted_loads(payload)
    return pickle.loads(payload)


# Top-level package names whose globals RestrictedUnpickler may resolve;
# a pickle referencing any other module is rejected.
# NOTE(review): every submodule of these packages is allowed (and imported on
# demand). numpy in particular is broad, so this is presumably not a hard
# security boundary against hostile payloads — confirm before relying on it.
safe_module = {
    'pyoi',
    'numpy'
}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals from allow-listed packages.

    A global ``module.name`` is allowed when the first dotted component of
    ``module`` is in ``safe_module``; anything else raises
    ``pickle.UnpicklingError``.

    NOTE(review): any attribute of any submodule of an allowed package can be
    resolved, so this narrows but presumably does not eliminate the risk of
    loading hostile data — review the allow-list before exposing it to
    untrusted input.
    """

    def find_class(self, module, name):
        import importlib
        top_level = module.partition('.')[0]
        if top_level not in safe_module:
            # Forbid everything else.
            raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                         (module, name))
        return getattr(importlib.import_module(module), name)


def restricted_loads(src):
    """Unpickle ``src`` bytes like :func:`pickle.loads`, but via RestrictedUnpickler."""
    stream = io.BytesIO(src)
    unpickler = RestrictedUnpickler(stream)
    return unpickler.load()


def trys(times=10, interval=0.1, set_retries=False):
    """Build a decorator that retries the wrapped function when it raises.

    :param times: total number of attempts; must be at least 1.
    :param interval: base delay in seconds. After the k-th failed attempt
        (1-based) the wrapper sleeps ``interval * k`` before trying again.
    :param set_retries: when True, the 0-based attempt index is passed to the
        wrapped function as the ``retries`` keyword argument.
    :raises ValueError: if ``times`` is less than 1.

    Every failure is logged with its traceback. The exception from the final
    attempt is re-raised.
    """
    if times < 1:
        # Otherwise the loop body never runs, and the wrapper would silently
        # return None without ever calling the function.
        raise ValueError(f"times must be >= 1, got {times}")

    def wrapper(func):
        @wraps(func)
        def decorated_function(*args, **kwargs):
            for n in range(times):
                try:
                    if set_retries:
                        kwargs["retries"] = n
                    return func(*args, **kwargs)
                except Exception:
                    if n + 1 >= times:
                        LOGGER.error(f"function {func.__name__} error, no retries", exc_info=True)
                        raise
                    LOGGER.warning(
                        f"function {func.__name__} error on attempt {n + 1}/{times}, trying again",
                        exc_info=True)
                # Linear backoff; using n + 1 makes the first retry wait too.
                time.sleep(interval * (n + 1))

        return decorated_function

    return wrapper


def get_lan_ip():
    """Best-effort lookup of this host's IPv4 LAN address.

    First resolves the host's FQDN. If that yields a loopback address
    (``127.*``) on a non-Windows system, it probes a fixed list of common
    interface names via ``ioctl``. The first interface address found replaces
    the loopback result.

    :returns: the address as a dotted-quad string.
    :raises OSError: propagated from the initial hostname resolution
        (e.g. ``socket.gaierror``).
    """
    if os.name != "nt":
        import fcntl
        import struct

        # 0x8915 is SIOCGIFADDR on Linux; on other systems the ioctl is
        # presumably rejected with OSError, which the caller loop tolerates.
        siocgifaddr = 0x8915

        def get_interface_ip(ifname):
            # Use a context manager so the probe socket is closed even when
            # ioctl raises for a missing interface (it previously leaked).
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                reply = fcntl.ioctl(s.fileno(), siocgifaddr,
                                    struct.pack('256s', string_to_bytes(ifname[:15])))
            # Bytes 20..24 of the reply are handed to inet_ntoa as the packed IPv4 address.
            return socket.inet_ntoa(reply[20:24])

    ip = socket.gethostbyname(socket.getfqdn())
    if ip.startswith("127.") and os.name != "nt":
        interfaces = (
            "bond1",
            "eth0",
            "eth1",
            "eth2",
            "wlan0",
            "wlan1",
            "wifi0",
            "ath0",
            "ath1",
            "ppp0",
        )
        for ifname in interfaces:
            try:
                ip = get_interface_ip(ifname)
                break
            except OSError:
                # Interface absent or without an address: try the next one.
                continue
    return ip or ''
